
import torch
import os
import math
import torchvision.transforms as transforms
from PIL import Image
from torchvision.models import resnet50, resnet101
from image_synthesis.data.image_folder import MyImageFolder

def get_folder_name_and_images(image_dir='RESULT/dalle_gpt_imagenet_class_id_lr3e-6none_Warmup4.5e-4_plateau_PredCond_g16_e15_generate'):
    """Group the images found under ``image_dir`` by the folder that holds them.

    Args:
        image_dir: root directory handed to ``MyImageFolder``.

    Returns:
        The result of ``MyImageFolder(image_dir).folder_name_to_im_path()``.
        ``rank`` in this file consumes it as a mapping from folder name to a
        sequence of image paths (presumably; confirm against MyImageFolder).
    """
    # NOTE(review): this default directory differs from rank()'s default
    # (..._g16_e15_generate vs ..._g16_val_e29_generate); rank() always passes
    # its own image_dir explicitly, so only direct callers see this default.
    folder_index = MyImageFolder(image_dir)
    return folder_index.folder_name_to_im_path()

def _load_batch(paths, preprocess):
    """Open each image in ``paths`` as RGB, preprocess it, and stack into one batch tensor."""
    tensors = []
    for path in paths:
        # convert('RGB') so grayscale / RGBA / palette files yield the 3 channels
        # the Normalize step expects; the context manager closes the file handle.
        with Image.open(path) as img:
            tensors.append(preprocess(img.convert('RGB')))
    return torch.stack(tensors, dim=0)


def _class_probs(model, preprocess, im_paths, class_id, batch_size, device):
    """Return a 1-D CPU tensor of the softmax probability of ``class_id`` per image, in ``im_paths`` order."""
    chunks = []
    with torch.no_grad():
        for start in range(0, len(im_paths), batch_size):
            batch = _load_batch(im_paths[start:start + batch_size], preprocess).to(device)
            batch_probs = model(batch).softmax(dim=-1)[:, class_id]
            chunks.append(batch_probs.cpu())
    if not chunks:
        # Empty folder: nothing to score.
        return torch.empty(0)
    return torch.cat(chunks, dim=0)


def _write_rank_file(rank_path, im_paths, probs):
    """Write ``"<file name>, <prob>"`` lines to ``rank_path``, highest probability first."""
    sorted_probs, order = torch.sort(probs, descending=True)
    with open(rank_path, 'w') as rank_file:
        for prob, idx in zip(sorted_probs.tolist(), order.tolist()):
            rank_file.write('{}, {}\n'.format(os.path.basename(im_paths[idx]), prob))


def rank(image_dir='RESULT/dalle_gpt_imagenet_class_id_lr3e-6none_Warmup4.5e-4_plateau_PredCond_g16_val_e29_generate', batch_size=10):
    """Rank the images of every class folder by a pretrained ResNet-50's confidence.

    Images are grouped per folder via ``get_folder_name_and_images``. Each folder
    name is parsed as an integer class id and used as a column index into the
    classifier's softmax output. Every image in the folder is scored by the
    probability assigned to that class. The images are then written, highest
    score first, to ``<image_dir>/<folder>/rank_results.txt``, one
    ``"<file name>, <prob>"`` line per image.

    Args:
        image_dir: root directory containing one sub-folder per class id.
        batch_size: number of images classified per forward pass (must be >= 1).

    Raises:
        ValueError: if ``batch_size`` < 1 or a folder name is not an integer.
    """
    if batch_size < 1:
        raise ValueError('batch_size must be >= 1, got {}'.format(batch_size))

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = resnet50(pretrained=True).to(device)
    # Inference must use the stored BatchNorm statistics; in train mode the scores
    # would depend on which images share a batch and the running stats would drift.
    model.eval()

    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    fn2im_path = get_folder_name_and_images(image_dir)
    for fn, im_paths in fn2im_path.items():
        cond_class_id = int(fn)
        probs = _class_probs(model, preprocess, im_paths, cond_class_id, batch_size, device)

        rank_path = os.path.join(image_dir, fn, 'rank_results.txt')
        _write_rank_file(rank_path, im_paths, probs)
        print('rank results saved in {}'.format(rank_path))



if __name__ == '__main__':
    # Script entry point: rank the images in rank()'s default results directory.
    rank()

